import torch.nn as nn
import torch.nn.functional as F
import torch
from torch_geometric.nn import GCNConv,global_mean_pool
from torch_geometric.nn import GCNConv, GATConv,GATv2Conv, global_mean_pool, BatchNorm

class Enhanced_large_GCNv1(nn.Module):
    def __init__(self, num_node_features, hidden_dim, output_dim=17, num_heads=8, dropout=0.2):
        super(Enhanced_large_GCNv1, self).__init__()
        
        # 第一层图卷积（GCNConv）
        self.conv1 = GATConv(num_node_features, hidden_dim * 2)  # 增加隐藏维度
        self.bn1 = BatchNorm(hidden_dim * 2)
        
        # 第二层图卷积（GATConv）使用多头注意力
        self.conv2 = GATConv(hidden_dim * 2, hidden_dim * 2, heads=num_heads, dropout = 0)  # 增加头数
        self.bn2 = BatchNorm(hidden_dim * 2 * num_heads)
        
        # 第三层图卷积（GCNConv）
        self.conv3 = GATConv(hidden_dim * 2 * num_heads, hidden_dim * 4)  # 增加维度
        self.bn3 = BatchNorm(hidden_dim * 4)
        
        # 第四层图卷积（GATConv）使用多头注意力
        self.conv4 = GATConv(hidden_dim * 4, hidden_dim * 4, heads=num_heads, dropout=0)  # 增加头数
        self.bn4 = BatchNorm(hidden_dim * 4 * num_heads)

        self.conv5 = GATConv(hidden_dim * 4 * num_heads, hidden_dim * 8)  # 增加维度
        self.bn5 = BatchNorm(hidden_dim * 8)

        self.conv6 = GATConv(hidden_dim * 8, hidden_dim * 4)  # 增加维度
        self.bn6 = BatchNorm(hidden_dim * 4)
        
        # 变换维度的线性层，适应残差连接
        self.res_fc2 = nn.Linear(hidden_dim * 2, hidden_dim * 2 * num_heads)  # 变换维度（conv2残差）
        self.res_fc4 = nn.Linear(hidden_dim * 4, hidden_dim * 4 * num_heads)  # 变换维度（conv4残差）

        # 全局池化层
        self.global_pool = global_mean_pool

        # 定义 Dropout 模块
        self.dropout = nn.Dropout(dropout)
    
        self.fc1 = nn.Linear(hidden_dim * 4, output_dim)  # 增加全连接层

    def get_feat_modules(self):
        """定义获取特征模块的方法，返回包含图卷积层的 ModuleList"""
        feat_m = nn.ModuleList([])

        feat_m.append(self.conv1)  # 第一层图卷积
        feat_m.append(self.conv2)  # 第二层图卷积
        feat_m.append(self.conv3)  # 第三层图卷积
        feat_m.append(self.conv4)  # 第四层图卷积
        feat_m.append(self.conv5)  # 第五层图卷积
        feat_m.append(self.conv6)  # 第六层图卷积

        feat_m.append(self.fc1)  # 第一层全连接

        return feat_m
        
    def forward(self, data):
        """
        前向传播。

        参数:
        - data (Batch): 批处理后的图数据。

        返回:
        - out (torch.Tensor): 输出张量，形状为 [batch_size, 17]。
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch  # 获取节点特征、边索引和批处理信息
        
        # 第一层图卷积 + 批量归一化 + 激活函数 + Dropout
        x = self.conv1(x, edge_index)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.dropout(x)

        # 第二层图卷积（GATConv） + 批量归一化 + 激活函数 + Dropout + 残差连接
        x_residual = x  # 保存原始输入用于残差连接
        x = self.conv2(x, edge_index)
        x = self.bn2(x)
        x = F.elu(x)
        x = self.dropout(x)
        x_residual = self.res_fc2(x_residual)  # 变换残差维度
        x = x + x_residual  # 添加残差连接

        feat1 = x

        # 第三层图卷积 + 批量归一化 + 激活函数 + Dropout
        x = self.conv3(x, edge_index)
        x = self.bn3(x)
        x = F.relu(x)
        x = self.dropout(x)
        
        # 第四层图卷积（GATConv） + 批量归一化 + 激活函数 + Dropout + 残差连接
        x_residual = x  # 保存原始输入用于残差连接
        x = self.conv4(x, edge_index)
        x = self.bn4(x)
        x = F.elu(x)
        x = self.dropout(x)
        x_residual = self.res_fc4(x_residual)  # 变换残差维度
        x = x + x_residual  # 添加残差连接

        feat2 = x

        x = self.conv5(x, edge_index)
        x = self.bn5(x)
        x = F.relu(x)
        x = self.dropout(x)

        x = self.conv6(x, edge_index)
        x = self.bn6(x)
        x = F.relu(x)
        x = self.dropout(x)

        feat3 = x

        # 全局池化，将节点特征汇聚为图级特征
        x = self.global_pool(x, batch)  # 形状为 [batch_size, hidden_dim * 4 * num_heads]

        # 全连接层 1 + 激活函数 + Dropout
        x = self.fc1(x)
        out = torch.sigmoid(x)
        
        return [feat1, feat2, feat3], out

class Enhanced_large_GCNv2(nn.Module):
    def __init__(self, num_node_features, hidden_dim, output_dim=17, num_heads=8, dropout=0.2):
        super(Enhanced_large_GCNv2, self).__init__()
        
        # 第一层图卷积（GCNConv）
        self.conv1 = GCNConv(num_node_features, hidden_dim * 2)  # 增加隐藏维度
        self.bn1 = BatchNorm(hidden_dim * 2)
        
        # 第二层图卷积（GATConv）使用多头注意力
        self.conv2 = GCNConv(hidden_dim * 2, hidden_dim * 2 * num_heads)  # 增加头数
        self.bn2 = BatchNorm(hidden_dim * 2 * num_heads)
        
        # 第三层图卷积（GCNConv）
        self.conv3 = GCNConv(hidden_dim * 2 * num_heads, hidden_dim * 4)  # 增加维度
        self.bn3 = BatchNorm(hidden_dim * 4)
        
        # 第四层图卷积（GATConv）使用多头注意力
        self.conv4 = GCNConv(hidden_dim * 4, hidden_dim * 4 * num_heads)  # 增加头数
        self.bn4 = BatchNorm(hidden_dim * 4 * num_heads)

        self.conv5 = GCNConv(hidden_dim * 4 * num_heads, hidden_dim * 8)  # 增加维度
        self.bn5 = BatchNorm(hidden_dim * 8)

        self.conv6 = GCNConv(hidden_dim * 8, hidden_dim * 4)  # 增加维度
        self.bn6 = BatchNorm(hidden_dim * 4)
        
        # 变换维度的线性层，适应残差连接
        self.res_fc2 = nn.Linear(hidden_dim * 2, hidden_dim * 2 * num_heads)  # 变换维度（conv2残差）
        self.res_fc4 = nn.Linear(hidden_dim * 4, hidden_dim * 4 * num_heads)  # 变换维度（conv4残差）

        # 全局池化层
        self.global_pool = global_mean_pool

        # 定义 Dropout 模块
        self.dropout = nn.Dropout(dropout)
    
        self.fc1 = nn.Linear(hidden_dim * 4, output_dim)  # 增加全连接层

    def get_feat_modules(self):
        """定义获取特征模块的方法，返回包含图卷积层的 ModuleList"""
        feat_m = nn.ModuleList([])

        feat_m.append(self.conv1)  # 第一层图卷积
        feat_m.append(self.conv2)  # 第二层图卷积
        feat_m.append(self.conv3)  # 第三层图卷积
        feat_m.append(self.conv4)  # 第四层图卷积
        feat_m.append(self.conv5)  # 第五层图卷积
        feat_m.append(self.conv6)  # 第六层图卷积

        feat_m.append(self.fc1)  # 第一层全连接

        return feat_m
        
    def forward(self, data):
        """
        前向传播。

        参数:
        - data (Batch): 批处理后的图数据。

        返回:
        - out (torch.Tensor): 输出张量，形状为 [batch_size, 17]。
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch  # 获取节点特征、边索引和批处理信息
        
        # 第一层图卷积 + 批量归一化 + 激活函数 + Dropout
        x = self.conv1(x, edge_index)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.dropout(x)

        # 第二层图卷积（GATConv） + 批量归一化 + 激活函数 + Dropout + 残差连接
        x_residual = x  # 保存原始输入用于残差连接
        x = self.conv2(x, edge_index)
        x = self.bn2(x)
        x = F.elu(x)
        x = self.dropout(x)
        x_residual = self.res_fc2(x_residual)  # 变换残差维度
        x = x + x_residual  # 添加残差连接

        feat1 = x

        # 第三层图卷积 + 批量归一化 + 激活函数 + Dropout
        x = self.conv3(x, edge_index)
        x = self.bn3(x)
        x = F.relu(x)
        x = self.dropout(x)
        
        # 第四层图卷积（GATConv） + 批量归一化 + 激活函数 + Dropout + 残差连接
        x_residual = x  # 保存原始输入用于残差连接
        x = self.conv4(x, edge_index)
        x = self.bn4(x)
        x = F.elu(x)
        x = self.dropout(x)
        x_residual = self.res_fc4(x_residual)  # 变换残差维度
        x = x + x_residual  # 添加残差连接

        feat2 = x

        x = self.conv5(x, edge_index)
        x = self.bn5(x)
        x = F.relu(x)
        x = self.dropout(x)

        x = self.conv6(x, edge_index)
        x = self.bn6(x)
        x = F.relu(x)
        x = self.dropout(x)

        feat3 = x

        # 全局池化，将节点特征汇聚为图级特征
        x = self.global_pool(x, batch)  # 形状为 [batch_size, hidden_dim * 4 * num_heads]

        # 全连接层 1 + 激活函数 + Dropout
        x = self.fc1(x)
        out = torch.sigmoid(x)
        
        return [feat1, feat2, feat3], out   

    
    
class Enhanced_large_GCNv3(nn.Module):
    def __init__(self, num_node_features, hidden_dim, output_dim=17, num_heads=8, dropout=0.2):
        super(Enhanced_large_GCNv3, self).__init__()
        
        # 第一层图卷积（GCNConv）
        self.conv1 = GCNConv(num_node_features, hidden_dim * 2) # 增加隐藏维度
        self.bn1 = BatchNorm(hidden_dim * 2)
        
        # 第二层图卷积（GATConv）使用多头注意力
        self.conv2 = GATConv(hidden_dim * 2, hidden_dim * 2, heads=num_heads)  # 增加头数
        self.bn2 = BatchNorm(hidden_dim * 2 * num_heads)
        
        # 第三层图卷积（GCNConv）
        self.conv3 = GCNConv(hidden_dim * 2 * num_heads, hidden_dim * 4)  # 增加维度
        self.bn3 = BatchNorm(hidden_dim * 4)
        
        # 第四层图卷积（GATConv）使用多头注意力
        self.conv4 = GATConv(hidden_dim * 4, hidden_dim * 4, heads=num_heads)  # 增加头数
        self.bn4 = BatchNorm(hidden_dim * 4 * num_heads)

        self.conv5 = GCNConv(hidden_dim * 4 * num_heads, hidden_dim * 8)  # 增加维度
        self.bn5 = BatchNorm(hidden_dim * 8)

        self.conv6 = GCNConv(hidden_dim * 8, hidden_dim * 4)  # 增加维度
        self.bn6 = BatchNorm(hidden_dim * 4)
        
        # 变换维度的线性层，适应残差连接
        self.res_fc2 = nn.Linear(hidden_dim * 2, hidden_dim * 2 * num_heads)  # 变换维度（conv2残差）
        self.res_fc4 = nn.Linear(hidden_dim * 4, hidden_dim * 4 * num_heads)  # 变换维度（conv4残差）

        # 全局池化层
        self.global_pool = global_mean_pool

        # 定义 Dropout 模块
        self.dropout = nn.Dropout(dropout)
    
        self.fc1 = nn.Linear(hidden_dim * 4, output_dim)  # 增加全连接层

    def get_feat_modules(self):
        """定义获取特征模块的方法，返回包含图卷积层的 ModuleList"""
        feat_m = nn.ModuleList([])

        feat_m.append(self.conv1)  # 第一层图卷积
        feat_m.append(self.conv2)  # 第二层图卷积
        feat_m.append(self.conv3)  # 第三层图卷积
        feat_m.append(self.conv4)  # 第四层图卷积
        feat_m.append(self.conv5)  # 第五层图卷积
        feat_m.append(self.conv6)  # 第六层图卷积

        feat_m.append(self.fc1)  # 第一层全连接

        return feat_m
        
    def forward(self, data):
        """
        前向传播。

        参数:
        - data (Batch): 批处理后的图数据。

        返回:
        - out (torch.Tensor): 输出张量，形状为 [batch_size, 17]。
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch  # 获取节点特征、边索引和批处理信息

        # print(f"x device: {x.device}, edge_index device: {edge_index.device}")

        # print(f'x.shape:{x.shape};edge_index.shape:{edge_index.shape};batch.shape:{batch.shape}',flush=True)

        # if torch.isnan(x).any() or torch.isinf(x).any():
        #     print("x contains NaN or Inf!")
        # if torch.isnan(edge_index).any() or torch.isinf(edge_index).any():
        #     print("edge_index contains NaN or Inf!")
        # if edge_index.shape[1] == 0:
        #     print("Warning: empty edge_index!")
        # x.cuda()
        
        # 第一层图卷积 + 批量归一化 + 激活函数 + Dropout
        x = self.conv1(x, edge_index)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.dropout(x)

        # 第二层图卷积（GATConv） + 批量归一化 + 激活函数 + Dropout + 残差连接
        x_residual = x  # 保存原始输入用于残差连接
        x = self.conv2(x, edge_index)
        x = self.bn2(x)
        x = F.elu(x)
        x = self.dropout(x)
        x_residual = self.res_fc2(x_residual)  # 变换残差维度
        x = x + x_residual  # 添加残差连接

        feat1 = x

        # 第三层图卷积 + 批量归一化 + 激活函数 + Dropout
        x = self.conv3(x, edge_index)
        x = self.bn3(x)
        x = F.relu(x)
        x = self.dropout(x)
        
        # 第四层图卷积（GATConv） + 批量归一化 + 激活函数 + Dropout + 残差连接
        x_residual = x  # 保存原始输入用于残差连接
        x = self.conv4(x, edge_index)
        x = self.bn4(x)
        x = F.elu(x)
        x = self.dropout(x)
        x_residual = self.res_fc4(x_residual)  # 变换残差维度
        x = x + x_residual  # 添加残差连接

        feat2 = x

        x = self.conv5(x, edge_index)
        x = self.bn5(x)
        x = F.relu(x)
        x = self.dropout(x)

        x = self.conv6(x, edge_index)
        x = self.bn6(x)
        x = F.relu(x)
        x = self.dropout(x)

        feat3 = x

        # 全局池化，将节点特征汇聚为图级特征
        x = self.global_pool(x, batch)  # 形状为 [batch_size, hidden_dim * 4 * num_heads]

        # 全连接层 1 + 激活函数 + Dropout
        x = self.fc1(x)
        out = torch.sigmoid(x)
        
        return [feat1, feat2, feat3], out
    





class Simplified_GCN(nn.Module):
    def __init__(self, num_node_features, hidden_dim, output_dim=17, num_heads=8, dropout=0.2):
        super(Simplified_GCN, self).__init__()

        # 第一层图卷积（GCNConv）
        self.conv1 = GCNConv(num_node_features, hidden_dim * 2)
        self.bn1 = BatchNorm(hidden_dim * 2)

        # 第二层图卷积（GATConv）使用多头注意力
        self.conv2 = GATConv(hidden_dim * 2, hidden_dim * 2, heads=num_heads, dropout=dropout)
        self.bn2 = BatchNorm(hidden_dim * 2 * num_heads)

        # 第三层图卷积（GCNConv）
        self.conv3 = GCNConv(hidden_dim * 2 * num_heads, hidden_dim * 4)
        self.bn3 = BatchNorm(hidden_dim * 4)

        # 第四层图卷积（GCNConv）
        self.conv4 = GCNConv(hidden_dim * 4, hidden_dim)
        self.bn4 = BatchNorm(hidden_dim)

        # 全局池化层
        self.global_pool = global_mean_pool

        # 定义 Dropout 模块
        self.dropout = nn.Dropout(dropout)

        # 全连接层
        self.fc1 = nn.Linear(hidden_dim, output_dim)

    def get_feat_modules(self):
        """定义获取特征模块的方法，返回包含图卷积层的 ModuleList"""
        feat_m = nn.ModuleList([])

        feat_m.append(self.conv1)  # 第一层图卷积
        feat_m.append(self.conv2)  # 第二层图卷积
        feat_m.append(self.conv3)  # 第三层图卷积
        feat_m.append(self.conv4)  # 第四层图卷积

        feat_m.append(self.fc1)  # 第一层全连接

        return feat_m

    def forward(self, data):
        """
        前向传播。

        参数:
        - data (Batch): 批处理后的图数据。

        返回:
        - out (torch.Tensor): 输出张量，形状为 [batch_size, 17]。
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch  # 获取节点特征、边索引和批处理信息
        
        # 第一层图卷积 + 批量归一化 + 激活函数 + Dropout
        x = self.conv1(x, edge_index)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.dropout(x)

        # 第二层图卷积（GATConv） + 批量归一化 + 激活函数 + Dropout
        feat1 = x  # 保留第一层特征
        x = self.conv2(x, edge_index)
        x = self.bn2(x)
        x = F.elu(x)
        x = self.dropout(x)

        # 第三层图卷积 + 批量归一化 + 激活函数 + Dropout
        feat2 = x  # 保留第二层特征
        x = self.conv3(x, edge_index)
        x = self.bn3(x)
        x = F.relu(x)
        x = self.dropout(x)

        # 第四层图卷积（GATConv） + 批量归一化 + 激活函数 + Dropout
          # 保留第三层特征
        x = self.conv4(x, edge_index)
        x = self.bn4(x)
        x = F.elu(x)
        x = self.dropout(x)
        feat3 = x

        # 全局池化，将节点特征汇聚为图级特征
        x = self.global_pool(x, batch)  # 形状为 [batch_size, hidden_dim * 4 * num_heads]

        # 全连接层 1 + 激活函数 + Dropout
        x = self.fc1(x)
        out = torch.sigmoid(x)
        
        return [feat1, feat2, feat3], out



class Simplified_GCNv2(nn.Module):
    def __init__(self, num_node_features, hidden_dim, output_dim=17, num_heads=8, dropout=0.2):
        super(Simplified_GCNv2, self).__init__()

        # 第一层图卷积（GCNConv）
        self.conv1 = GCNConv(num_node_features, hidden_dim * 2)  # 维度与GAT时相同
        self.bn1 = BatchNorm(hidden_dim * 2)

        # 第二层图卷积（GCNConv）
        self.conv2 = GCNConv(hidden_dim * 2, hidden_dim * num_heads * 2)  # 维度与GAT时相同
        self.bn2 = BatchNorm(hidden_dim * num_heads * 2)

        # 第三层图卷积（GCNConv）
        self.conv3 = GCNConv(hidden_dim * num_heads * 2, hidden_dim * 4)  # 维度与GAT时相同
        self.bn3 = BatchNorm(hidden_dim * 4)

        # 第四层图卷积（GCNConv）
        self.conv4 = GCNConv(hidden_dim * 4, hidden_dim)  # 维度与GAT时相同
        self.bn4 = BatchNorm(hidden_dim)

        # 全局池化层
        self.global_pool = global_mean_pool

        # 定义 Dropout 模块
        self.dropout = nn.Dropout(dropout)

        # 全连接层
        self.fc1 = nn.Linear(hidden_dim, output_dim)  # 全连接层的输入维度要根据num_heads调整

    def get_feat_modules(self):
        """定义获取特征模块的方法，返回包含图卷积层的 ModuleList"""
        feat_m = nn.ModuleList([])

        feat_m.append(self.conv1)  # 第一层图卷积
        feat_m.append(self.conv2)  # 第二层图卷积
        feat_m.append(self.conv3)  # 第三层图卷积
        feat_m.append(self.conv4)  # 第四层图卷积

        feat_m.append(self.fc1)  # 第一层全连接

        return feat_m

    def forward(self, data):
        """
        前向传播。

        参数:
        - data (Batch): 批处理后的图数据。

        返回:
        - out (torch.Tensor): 输出张量，形状为 [batch_size, 17]。
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch  # 获取节点特征、边索引和批处理信息
        
        # 第一层图卷积 + 批量归一化 + 激活函数 + Dropout
        x = self.conv1(x, edge_index)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.dropout(x)

        # 第二层图卷积 + 批量归一化 + 激活函数 + Dropout
        feat1 = x  # 保留第一层特征
        x = self.conv2(x, edge_index)
        x = self.bn2(x)
        x = F.elu(x)
        x = self.dropout(x)

        # 第三层图卷积 + 批量归一化 + 激活函数 + Dropout
        feat2 = x  # 保留第二层特征
        x = self.conv3(x, edge_index)
        x = self.bn3(x)
        x = F.relu(x)
        x = self.dropout(x)

        # 第四层图卷积 + 批量归一化 + 激活函数 + Dropout
        feat3 = x  # 保留第三层特征
        x = self.conv4(x, edge_index)
        x = self.bn4(x)
        x = F.elu(x)
        x = self.dropout(x)

        # 全局池化，将节点特征汇聚为图级特征
        x = self.global_pool(x, batch)  # 形状为 [batch_size, hidden_dim * num_heads]

        # 全连接层 1 + 激活函数 + Dropout
        x = self.fc1(x)
        out = torch.sigmoid(x)
        
        return [feat1, feat2, feat3], out
    


class Simplified_GCNv3(nn.Module):
    def __init__(self, num_node_features, hidden_dim, output_dim=17, num_heads=8, dropout=0.2):
        super(Simplified_GCNv3, self).__init__()

        # 第一层图卷积（GCNConv）
        self.conv1 = GCNConv(num_node_features, hidden_dim * 2)  # 维度与GAT时相同
        self.bn1 = BatchNorm(hidden_dim * 2)

        # 第二层图卷积（GCNConv）
        self.conv2 = GCNConv(hidden_dim * 2, hidden_dim *  2 * num_heads)  # 维度与GAT时相同
        self.bn2 = BatchNorm(hidden_dim * 2 * num_heads)

        # 第三层图卷积（GCNConv）
        self.conv3 = GCNConv(hidden_dim * 2 * num_heads, hidden_dim* 4)  # 维度与GAT时相同
        self.bn3 = BatchNorm(hidden_dim * 4)

        # 第四层图卷积（GCNConv）
        self.conv4 = GCNConv(hidden_dim * 4, hidden_dim)  # 维度与GAT时相同
        self.bn4 = BatchNorm(hidden_dim)

        # 全局池化层
        self.global_pool = global_mean_pool

        # 定义 Dropout 模块
        self.dropout = nn.Dropout(dropout)

        # 全连接层
        self.fc1 = nn.Linear(hidden_dim, output_dim)  # 全连接层的输入维度要根据num_heads调整

    def get_feat_modules(self):
        """定义获取特征模块的方法，返回包含图卷积层的 ModuleList"""
        feat_m = nn.ModuleList([])

        feat_m.append(self.conv1)  # 第一层图卷积
        feat_m.append(self.conv2)  # 第二层图卷积
        feat_m.append(self.conv3)  # 第三层图卷积
        feat_m.append(self.conv4)  # 第四层图卷积

        feat_m.append(self.fc1)  # 第一层全连接

        return feat_m

    def forward(self, data):
        """
        前向传播。

        参数:
        - data (Batch): 批处理后的图数据。

        返回:
        - out (torch.Tensor): 输出张量，形状为 [batch_size, 17]。
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch  # 获取节点特征、边索引和批处理信息
        
        # 第一层图卷积 + 批量归一化 + 激活函数 + Dropout
        x = self.conv1(x, edge_index)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.dropout(x)

        # 第二层图卷积 + 批量归一化 + 激活函数 + Dropout
        feat1 = x  # 保留第一层特征
        x = self.conv2(x, edge_index)
        x = self.bn2(x)
        x = F.elu(x)
        x = self.dropout(x)

        # 第三层图卷积 + 批量归一化 + 激活函数 + Dropout
        feat2 = x  # 保留第二层特征
        x = self.conv3(x, edge_index)
        x = self.bn3(x)
        x = F.relu(x)
        x = self.dropout(x)

        # 第四层图卷积 + 批量归一化 + 激活函数 + Dropout
        feat3 = x  # 保留第三层特征
        x = self.conv4(x, edge_index)
        x = self.bn4(x)
        x = F.elu(x)
        x = self.dropout(x)

        # 全局池化，将节点特征汇聚为图级特征
        x = self.global_pool(x, batch)  # 形状为 [batch_size, hidden_dim * num_heads]

        # 全连接层 1 + 激活函数 + Dropout
        x = self.fc1(x)
        out = torch.sigmoid(x)
        
        return [feat1, feat2, feat3], out


class Enhanced_large_GCNv2(nn.Module):
    def __init__(self, num_node_features, hidden_dim, output_dim=17, num_heads=8, dropout=0.2):
        super(Enhanced_large_GCNv2, self).__init__()
        
        # 第一层图卷积（GCNConv）
        self.conv1 = GCNConv(num_node_features, hidden_dim * 2)  # 增加隐藏维度
        self.bn1 = BatchNorm(hidden_dim * 2)
        
        # 第二层图卷积（GATConv）使用多头注意力
        self.conv2 = GCNConv(hidden_dim * 2, hidden_dim * 2 * num_heads)  # 增加头数
        self.bn2 = BatchNorm(hidden_dim * 2 * num_heads)
        
        # 第三层图卷积（GCNConv）
        self.conv3 = GCNConv(hidden_dim * 2 * num_heads, hidden_dim * 4)  # 增加维度
        self.bn3 = BatchNorm(hidden_dim * 4)
        
        # 第四层图卷积（GATConv）使用多头注意力
        self.conv4 = GCNConv(hidden_dim * 4, hidden_dim * 4 * num_heads)  # 增加头数
        self.bn4 = BatchNorm(hidden_dim * 4 * num_heads)

        self.conv5 = GCNConv(hidden_dim * 4 * num_heads, hidden_dim * 8)  # 增加维度
        self.bn5 = BatchNorm(hidden_dim * 8)

        self.conv6 = GCNConv(hidden_dim * 8, hidden_dim * 4)  # 增加维度
        self.bn6 = BatchNorm(hidden_dim * 4)
        
        # 变换维度的线性层，适应残差连接
        self.res_fc2 = nn.Linear(hidden_dim * 2, hidden_dim * 2 * num_heads)  # 变换维度（conv2残差）
        self.res_fc4 = nn.Linear(hidden_dim * 4, hidden_dim * 4 * num_heads)  # 变换维度（conv4残差）

        # 全局池化层
        self.global_pool = global_mean_pool

        # 定义 Dropout 模块
        self.dropout = nn.Dropout(dropout)
    
        self.fc1 = nn.Linear(hidden_dim * 4, output_dim)  # 增加全连接层

    def get_feat_modules(self):
        """定义获取特征模块的方法，返回包含图卷积层的 ModuleList"""
        feat_m = nn.ModuleList([])

        feat_m.append(self.conv1)  # 第一层图卷积
        feat_m.append(self.conv2)  # 第二层图卷积
        feat_m.append(self.conv3)  # 第三层图卷积
        feat_m.append(self.conv4)  # 第四层图卷积
        feat_m.append(self.conv5)  # 第五层图卷积
        feat_m.append(self.conv6)  # 第六层图卷积

        feat_m.append(self.fc1)  # 第一层全连接

        return feat_m
        
    def forward(self, data):
        """
        前向传播。

        参数:
        - data (Batch): 批处理后的图数据。

        返回:
        - out (torch.Tensor): 输出张量，形状为 [batch_size, 17]。
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch  # 获取节点特征、边索引和批处理信息
        
        # 第一层图卷积 + 批量归一化 + 激活函数 + Dropout
        x = self.conv1(x, edge_index)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.dropout(x)

        # 第二层图卷积（GATConv） + 批量归一化 + 激活函数 + Dropout + 残差连接
        x_residual = x  # 保存原始输入用于残差连接
        x = self.conv2(x, edge_index)
        x = self.bn2(x)
        x = F.elu(x)
        x = self.dropout(x)
        x_residual = self.res_fc2(x_residual)  # 变换残差维度
        x = x + x_residual  # 添加残差连接

        feat1 = x

        # 第三层图卷积 + 批量归一化 + 激活函数 + Dropout
        x = self.conv3(x, edge_index)
        x = self.bn3(x)
        x = F.relu(x)
        x = self.dropout(x)
        
        # 第四层图卷积（GATConv） + 批量归一化 + 激活函数 + Dropout + 残差连接
        x_residual = x  # 保存原始输入用于残差连接
        x = self.conv4(x, edge_index)
        x = self.bn4(x)
        x = F.elu(x)
        x = self.dropout(x)
        x_residual = self.res_fc4(x_residual)  # 变换残差维度
        x = x + x_residual  # 添加残差连接

        feat2 = x

        x = self.conv5(x, edge_index)
        x = self.bn5(x)
        x = F.relu(x)
        x = self.dropout(x)

        x = self.conv6(x, edge_index)
        x = self.bn6(x)
        x = F.relu(x)
        x = self.dropout(x)

        feat3 = x

        # 全局池化，将节点特征汇聚为图级特征
        x = self.global_pool(x, batch)  # 形状为 [batch_size, hidden_dim * 4 * num_heads]

        # 全连接层 1 + 激活函数 + Dropout
        x = self.fc1(x)
        out = torch.sigmoid(x)
        
        return [feat1, feat2, feat3], out